"""
Merge patent-to-patent similarity scores between min_year and max_year
Robustness check: Validation using cosine similarity as a continuous cutoff
2023-06-11
"""

import pandas as pd
import os
import numpy as np
import glob
import logging

# Every relative path used below is resolved against this data directory.
os.chdir(r'/Volumes/Zihao_SSD2/PatentsView/')

# Module-level logger; the root logger is configured to emit ERROR and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.ERROR)

def load_patent_dict(path="patentsberta/patent_dict.csv"):
    """
    Load the patent lookup table that maps a patent index to its id and year.

    Args:
        path: CSV file containing at least the columns 'idx', 'patent_id'
            and 'patent_year' (in any order; other columns are ignored).
            'patent_year' is read as int16, so it must not contain missing
            values (pandas raises when an integer column has NAs).
    Returns:
        df ['patent_idx', 'patent_id', 'patent_year'] with duplicated
        (patent_idx, patent_id) pairs removed (first occurrence kept). A
        warning is printed when such duplicates are found.
    """
    dfpatent = pd.read_csv(
        path,
        usecols=["idx", "patent_id", "patent_year"],
        # Key the dtype on the actual column name, 'patent_year'.
        dtype={"idx": "int64", "patent_id": int, "patent_year": "int16"},
    )
    # read_csv returns the selected columns in file order, not in `usecols`
    # order, so rename by name and select explicitly instead of assigning
    # labels by position.
    dfpatent = dfpatent.rename(columns={"idx": "patent_idx"})[
        ["patent_idx", "patent_id", "patent_year"]
    ]
    if dfpatent.duplicated(["patent_idx", "patent_id"]).any():
        print("warning. duplicated [patent_idx, patent_id] found in patent_dict.csv!")
    return dfpatent.drop_duplicates(["patent_idx", "patent_id"])

def merge_years(min_year: int, max_year: int, cutoff=0.75) -> pd.DataFrame:
    """
    merge all years of processed cosine similarity DataFrames
    Args:
        min_year: min patent year to merge (inclusive)
        max_year: max patent year to merge (inclusive)
        cutoff: only pairs with sim_score >= cutoff are kept
    Returns:
        df ['patent_id', 'patent_year', 'cited_patent_id', 'cited_patent_year', 'sim_score', 'patent_idx', 'cited_patent_idx']
        sorted by patent_idx ascending, then sim_score and cited_patent_idx descending
    Raises:
        ValueError: if the year range is empty, if no CSV file is found for a
            year, or if merging with the patent dictionary changes the number
            of rows (an index missing from, or repeated in, the dictionary)
    """
    if min_year > max_year:
        raise ValueError(f"empty year range: min_year={min_year} > max_year={max_year}")

    print('Loading patent_dict.csv...')
    dfpatent = load_patent_dict()
    # Same lookup table keyed for the cited side. Naming its columns up front
    # means the result does not depend on merge's automatic _x/_y suffixes.
    df_cited_patent = dfpatent.rename(
        columns={
            "patent_idx": "cited_patent_idx",
            "patent_id": "cited_patent_id",
            "patent_year": "cited_patent_year",
        }
    )

    sort_keys = ["patent_idx", "sim_score", "cited_patent_idx"]
    # Nullable Int64 lets rows with a missing cited index be read and then dropped.
    sim_dtypes = {
        "patent_idx": "Int64",
        "cited_patent_idx": "Int64",
        "sim_score": "float32",
    }

    yr_list = []
    for yr in range(min_year, max_year + 1):
        _dir = glob.glob(f"patentsberta/patent_sim_combined/sim_{yr}_cutoff/*.csv")
        print(f"start process and merge {len(_dir)} files for year {yr}")
        if not _dir:
            # pd.concat on an empty list would fail with an unhelpful message.
            raise ValueError(
                f"no CSV files found in patentsberta/patent_sim_combined/sim_{yr}_cutoff/"
            )
        df_file = pd.concat(
            [pd.read_csv(file, dtype=sim_dtypes) for file in _dir],
            ignore_index=True,
        )

        df_file = df_file.dropna(subset=["cited_patent_idx"])
        # Cast to plain int64 so the merge keys match the dictionary's dtype.
        df_file = df_file.astype({"cited_patent_idx": "int64", "patent_idx": "int64"})

        n_before = len(df_file)
        df_file = df_file.merge(dfpatent, on=["patent_idx"]).merge(
            df_cited_patent, on=["cited_patent_idx"]
        )

        print(len(df_file))
        # Explicit check instead of `assert`, which is skipped under `python -O`.
        if len(df_file) != n_before:
            raise ValueError(
                f"year {yr}: row count changed from {n_before} to {len(df_file)} after "
                "merging with patent_dict (missing or duplicated patent index)"
            )

        # No per-year sort: the concatenated frame is sorted once below.
        df_file = df_file[df_file['sim_score'] >= cutoff]
        print(len(df_file))
        yr_list.append(df_file)

    df = pd.concat(yr_list, ignore_index=True)
    # Drop the per-year frames before sorting to lower peak memory.
    del yr_list
    df = df.sort_values(sort_keys, ascending=[True, False, False])

    return df[['patent_id', 'patent_year', 'cited_patent_id', 'cited_patent_year', 'sim_score', 'patent_idx', 'cited_patent_idx']]

# Merge the 1981-2015 window with the default cutoff, report the size, and save.
df_sim = merge_years(min_year=1981, max_year=2015)
print("nrows of df_sim: ", df_sim.shape[0])
df_sim.to_csv("cleandata/sim_score_1981_2015_cutoff.csv", index=False)